import numpy as np
import torch
from torch.distributions import Normal
from torch.optim import Adam, RMSprop, SGD
import itertools
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import euclidean_distances
from pymoo.factory import get_performance_indicator
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
import gym
import gymnasium
from gym import Wrapper
import d4rl
import dsrl
import os
import pickle
import time
import sys
from lib.utilities.hypervolume import InnerHyperVolume 
import lib.common_ptan as ptan
import torch.nn.functional as F
import scipy.stats
import moenvs
import copy

def L1_norm(x, epsilon=1e-6):
    """Normalize ``x`` along its last axis so its entries sum (approximately) to one.

    ``epsilon`` is added to every entry and to the L1 norm, which keeps the
    division finite for an all-zero input.

    Args:
        x: ``torch.Tensor`` or ``np.ndarray``. Other array-likes (lists,
            tuples, ...) are converted with ``np.asarray``.
        epsilon: smoothing constant added to numerator and denominator.

    Returns:
        A tensor for tensor input, otherwise an ``np.ndarray``.

    Raises:
        TypeError: if a non-array input cannot be interpreted as numeric data.
    """
    if torch.is_tensor(x):
        return (x + epsilon) / (x.norm(p=1, dim=-1, keepdim=True) + epsilon)
    if not isinstance(x, np.ndarray):
        # Previously any non-ndarray input fell through and raised UnboundLocalError.
        x = np.asarray(x)
        if x.dtype.kind not in 'biuf':
            raise TypeError(f"L1_norm expects a tensor or numeric array, got dtype {x.dtype}")
    return (x + epsilon) / (np.linalg.norm(x, ord=1, axis=-1, keepdims=True) + epsilon)

def compute_sparsity(obj_batch):
    """Return the sparsity of the first non-dominated front of ``obj_batch``.

    The front is found with pymoo's ``NonDominatedSorting`` (pymoo minimizes,
    so callers pass negated objectives when maximizing). For every objective
    the front's values are sorted and the squared gaps between neighbours are
    summed. The total over all objectives is divided by ``front_size - 1``.
    A front holding at most one point has sparsity 0.
    """
    front_idx = NonDominatedSorting().do(obj_batch, only_non_dominated_front=True)
    front = obj_batch[front_idx]
    n_points = len(front)
    if n_points <= 1:
        return 0

    # Column-wise sort, then neighbour differences along the sorted axis.
    squared_gaps = np.power(np.diff(np.sort(front, axis=0), axis=0), 2)
    total = 0
    # Accumulate objective by objective, left to right, like a plain double loop.
    for column in squared_gaps.T:
        total += sum(column)
    return total / (n_points - 1)

def reject_sample(dist, size, max_tries=10):
    """Draw ``size`` samples from ``dist`` whose components are all strictly positive.

    Rejection sampling: each round draws ``size * 10`` reparameterized samples
    via ``dist.rsample`` and keeps the rows whose entries are all > 0. Rounds
    repeat until ``size`` rows have been accepted or ``max_tries`` rounds ran.

    Args:
        dist: a ``torch.distributions`` object whose ``rsample((n,))`` returns a
            2-D ``(n, dim)`` tensor (rows are tested with ``.all(dim=1)``).
        size: number of samples to return.
        max_tries: maximum number of sampling rounds.

    Returns:
        Tensor with the first ``size`` accepted samples.

    Raises:
        RuntimeError: if fewer than ``size`` samples were accepted within
            ``max_tries`` rounds. (The previous assertion could never fire and
            silently returned too few samples.)
    """
    accepted_chunks = []
    n_accepted = 0
    for _ in range(max_tries):
        sample = dist.rsample((size * 10,))
        accepted = sample[(sample > 0).all(dim=1)]
        accepted_chunks.append(accepted)
        n_accepted += len(accepted)
        if n_accepted >= size:
            break
    else:
        raise RuntimeError("Rejected sample number is over limit")
    return torch.cat(accepted_chunks, dim=0)[:size]

def generate_w_batch_test(args, step_size, reward_size=None):
    """Enumerate preference vectors on a regular grid over the probability simplex.

    Each coordinate takes the values ``np.arange(0, 1 + step_size, step_size)``.
    Only combinations whose coordinates sum to 1 (up to floating-point error)
    are kept.

    Args:
        args: namespace providing ``reward_size`` when ``reward_size`` is None.
        step_size: grid spacing.
        reward_size: number of objectives, i.e. the length of each preference.

    Returns:
        ``np.ndarray`` of shape ``(n_prefs, reward_size)`` with unique rows in
        ``np.unique(..., axis=0)`` order.
    """
    if reward_size is None:
        reward_size = args.reward_size
    grid = np.arange(0, 1 + step_size, step_size)
    w_batch_test = np.array(list(itertools.product(grid, repeat=reward_size)))
    # Grid values such as 3 * 0.1 carry rounding error, so an exact ``== 1``
    # comparison silently drops valid preferences; compare with a tolerance.
    on_simplex = np.isclose(w_batch_test.sum(axis=1), 1.0)
    return np.unique(w_batch_test[on_simplex, :], axis=0)


def plot_pareto_front(args, objs, name, ext=''):
    """Scatter-plot evaluated objective vectors and save the figure to disk.

    Points on the first non-dominated front are drawn as circles; dominated
    points are drawn as crosses. Objectives are maximized, hence the negation
    passed to pymoo's ``NonDominatedSorting``. Points are coloured by their
    label. Only 2- and 3-objective results get a scatter plot.

    Args:
        args: namespace with ``env`` (used in the title) and ``seed``.
        objs: sequence of ``(label, objective_vector)`` pairs. When ``label`` is
            iterable, its first element is used as the colour value.
        name: experiment name used in the output directory.
        ext: file stem of the saved ``.png``.

    For at most 3 objectives the figure is written to
    ``Figures/{name}/{seed}/{ext}.png``. The figure created here is closed
    afterwards, so repeated calls do not accumulate open figures.
    """
    plt.clf()
    all_labels = np.array([label[0] if hasattr(label, '__iter__') else label for label, _ in objs])
    vec_objs = np.array([obj for _, obj in objs])
    n_obj = len(vec_objs[0])
    non_dom = NonDominatedSorting().do(-vec_objs, only_non_dominated_front=True)
    is_non_dom = np.zeros(len(vec_objs), dtype=bool)
    is_non_dom[non_dom] = True
    dom = np.flatnonzero(~is_non_dom)

    fig = None
    if n_obj == 2:
        fig = plt.figure(figsize=(7, 5))
        sizes = np.where(is_non_dom, 20, 0)  # hide dominated points in the circle layer
        plt.scatter(vec_objs[dom, 0], vec_objs[dom, 1], c=all_labels[dom], marker='x')
        sc = plt.scatter(vec_objs[:, 0], vec_objs[:, 1], c=all_labels[:], s=sizes, marker='o')
        plt.xlabel("objective 1")
        plt.ylabel("objective 2")
        plt.colorbar(sc)
    elif n_obj == 3:
        fig = plt.figure(figsize=(20, 20))
        sizes = np.where(is_non_dom, 10, 0)
        ax1 = fig.add_subplot(1, 1, 1, projection="3d")
        ax1.scatter(vec_objs[dom, 0], vec_objs[dom, 1], vec_objs[dom, 2], c=all_labels[dom], s=10, alpha=1.0, marker='x')
        sc = ax1.scatter(vec_objs[:, 0], vec_objs[:, 1], vec_objs[:, 2], c=all_labels[:], s=sizes, alpha=1.0, marker='o')
        ax1.set_xlabel("objective 1")
        ax1.set_ylabel("objective 2")
        ax1.set_zlabel("objective 3")
        plt.colorbar(sc)
        ax1.view_init(45, 215)

    if n_obj <= 3:
        plt.title('{} - Pareto Front'.format(args.env))
        out_dir = 'Figures/{}/{}'.format(name, args.seed)
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig('Figures/{}/{}/{}.png'.format(name, args.seed, ext))
    if fig is not None:
        plt.close(fig)


class MOOfflineEnv(Wrapper):
    """Multi-objective wrapper around offline-RL benchmark environments.

    Supported ``dataset_class`` values:

    * ``'d4rl'`` -- gym env ``{env_name.lower()}-medium-v2``. A two-objective
      reward is rebuilt from ``qpos``: forward progress and ``4 - ||a||^2``,
      each plus an alive bonus of 1.
    * ``'d4morl'`` / ``'cmo'`` -- gym env ``env_name``. Offline data is read
      from pickles under ``./data/``; ``'cmo'`` appends the per-step cost as an
      extra objective.
    * ``'safe'`` -- gymnasium env ``env_name``. The step cost
      (``info['cost']``) becomes the second objective.

    Args:
        env_name: environment name / id.
        safe_obj_list: per-objective boolean mask of constraint objectives.
            Defaults to all ``False``.
        dataset_class: one of the values above.
        num_objective: number of objectives.
        seed: accepted for API compatibility; not used by this class.

    Raises:
        ValueError: if ``dataset_class`` is not supported.
    """

    def __init__(self, env_name, safe_obj_list=None, dataset_class='d4rl', num_objective=2, seed=None):
        supported = ('d4rl', 'd4morl', 'cmo', 'safe')
        if dataset_class not in supported:
            raise ValueError(f"Unsupported dataset_class {dataset_class!r}; expected one of {supported}")
        self.num_objective = num_objective
        self.env_name = env_name
        self.dataset_class = dataset_class
        self.safe_obj_list = np.array([False] * num_objective if safe_obj_list is None else safe_obj_list)
        if dataset_class == 'd4rl':
            env = gym.make(f"{self.env_name.lower()}-medium-v2")
        elif dataset_class == 'safe':
            env = gymnasium.make(env_name)
        else:  # 'd4morl' or 'cmo'
            env = gym.make(env_name)
            env._max_episode_steps += 1  # TODO: workaround for bugs in moenvs
        env.reward_space = np.zeros((num_objective,))
        self._max_episode_steps = env._max_episode_steps
        super(MOOfflineEnv, self).__init__(env)

    def cal_reward_in_d4rl(self, action):
        """Step the D4RL env and rebuild the reward from ``qpos``.

        With several objectives the reward is
        ``[(x_after - x_before) / dt + 1, 4 - ||action||^2 + 1]``. With a single
        objective the env's own scalar reward is returned in a 1-element array.

        Returns:
            gym-style ``(obs, reward, done, info)``.
        """
        xposbefore = self.env.sim.data.qpos[0]
        obs, reward, done, info = self.env.step(action)
        if self.num_objective == 1:
            return obs, np.array([reward]), done, info
        alive_bonus = 1.0
        xposafter = self.env.sim.data.qpos[0]
        reward1 = (xposafter - xposbefore) / self.env.dt + alive_bonus
        reward2 = 4.0 - 1.0 * np.square(action).sum() + alive_bonus
        return obs, np.array([reward1, reward2]), done, info

    def step(self, action):
        """Step the wrapped env and return a vector reward, also stored in ``info['obj']``.

        Returns:
            ``(obs, reward, done, info)`` for gym-based classes. For ``'safe'``
            returns the gymnasium-style 5-tuple
            ``(obs, reward, terminated, truncated, info)`` with reward ``[r, cost]``.
        """
        if self.dataset_class == 'safe':
            obs, reward, done, timeout, info = self.env.step(action)
            reward = np.array([reward, info['cost']])
            info['obj'] = reward
            return obs, reward, done, timeout, info
        if self.dataset_class == 'd4rl':
            obs, reward, done, info = self.cal_reward_in_d4rl(action)
        else:  # 'd4morl' / 'cmo': reward is passed through unchanged
            obs, reward, done, info = self.env.step(action)
        info['obj'] = reward
        return obs, reward, done, info

    def get_dataset(self, dataset_type):
        """Load the offline training data as a list of per-trajectory dicts.

        Args:
            dataset_type: quality tag inserted into the D4RL env id
                (``{env}-{dataset_type}-v2``) or the D4MORL file name. Ignored
                for ``'safe'`` and ``'cmo'``.

        Returns:
            List of trajectory dicts. For ``'d4rl'`` / ``'safe'`` the keys are
            ``observations``, ``actions``, ``next_observations``,
            ``raw_rewards``, ``terminals`` and ``preference``. For
            ``'d4morl'`` / ``'cmo'`` the dicts come straight from the pickles.
        """
        if self.dataset_class == 'd4rl':
            env = gym.make(f"{self.env_name.lower()}-{dataset_type}-v2")
            return self.d4rl2morl(env.get_dataset(), env)
        if self.dataset_class == 'safe':
            return self.safe2morl(self.env.get_dataset())

        # NOTE: pickle can execute arbitrary code on load; only use trusted data files.
        if self.dataset_class == 'd4morl':
            dataset_path = f"./data/d4morl/{self.env_name}_50000_{dataset_type}.pkl"
            with open(dataset_path, 'rb') as f:
                dataset = pickle.load(f)
        else:  # 'cmo': merge every matching run directory, skipping backups
            dataset = []
            data_root = "./data/cmo"
            # Sorted so the merged trajectory order is reproducible across file systems.
            for dirname in sorted(os.listdir(data_root)):
                path = os.path.join(data_root, dirname, 'dataset.pkl')
                if self.env_name in dirname and 'backup' not in dirname and os.path.exists(path):
                    with open(path, 'rb') as f:
                        dataset.extend(pickle.load(f))
            # Append the per-step cost as an extra reward column.
            for traj in dataset:
                traj['raw_rewards'] = np.concatenate([traj['raw_rewards'], traj['costs'].reshape(-1, 1)], axis=1)
        if 'Hopper' in self.env.spec.id:
            self._rescale_hopper_actions(dataset)
        return dataset

    @staticmethod
    def _rescale_hopper_actions(dataset):
        """Divide every trajectory's stored actions in place by ``[2, 2, 4]``."""
        scale = np.array([[2, 2, 4]])
        for traj in dataset:
            traj['actions'] /= scale

    def get_test_dataset(self, data_dir=None):
        """Load the pickled test set ``{data_dir}/{env_name}.pkl``.

        ``data_dir`` defaults to ``./data/{dataset_class}/test``.

        Raises:
            NotImplementedError: for ``'d4rl'``, which has no test set.
        """
        if self.dataset_class not in ('d4morl', 'cmo', 'safe'):
            raise NotImplementedError(f"No test dataset for dataset_class {self.dataset_class!r}")
        if data_dir is None:
            data_dir = f"./data/{self.dataset_class}/test"
        dataset_path = os.path.join(data_dir, f"{self.env_name}.pkl")
        # NOTE: pickle can execute arbitrary code on load; only use trusted data files.
        with open(dataset_path, 'rb') as f:
            return pickle.load(f)

    def d4rl2morl(self, dataset, env, max_episode_len=1000):
        """Convert a flat D4RL dataset into per-episode multi-objective trajectories.

        Rewards are recomputed from ``dataset['infos/qpos']`` and
        ``dataset['actions']`` using the same two objectives as
        ``cal_reward_in_d4rl``. Episodes end at terminals, timeouts, or after
        ``max_episode_len`` steps.
        """
        def calculate_reward_vector(info_qpos, action, timeouts, terminals):
            n_steps = action.shape[0]
            alive_bonus = 1.0
            reward1, reward2 = [], []
            for i in range(n_steps):
                # The final sample has no successor, so its progress term is 0.
                xposafter, xposbefore = info_qpos[min(i + 1, n_steps - 1), 0], info_qpos[i, 0]
                r1 = (xposafter - xposbefore) / env.dt + alive_bonus
                r2 = 4.0 - 1.0 * np.square(action[i]).sum() + alive_bonus
                # At an episode boundary qpos[i + 1] presumably belongs to the next episode,
                # so reuse the previous progress reward (when one exists) as an approximation.
                if (bool(terminals[i]) or timeouts[i]) and reward1:
                    r1 = reward1[-1]
                reward1.append(r1)
                reward2.append(r2)
            return np.stack([np.array(reward1), np.array(reward2)], axis=1)

        reward = calculate_reward_vector(dataset['infos/qpos'], dataset['actions'],
                                         dataset['timeouts'], dataset['terminals'])
        assert len(reward) == len(dataset['observations'])
        return self._split_trajectories(dataset, reward, max_episode_len=max_episode_len)

    def safe2morl(self, dataset, env=None):
        """Convert a flat safe-RL dataset into trajectories with reward ``[reward, cost]``.

        ``env`` is accepted for signature compatibility and is unused.
        """
        reward = np.stack([dataset['rewards'], dataset['costs']], axis=1)
        return self._split_trajectories(dataset, reward)

    @staticmethod
    def _split_trajectories(dataset, reward, max_episode_len=None):
        """Cut flat transition arrays into per-episode trajectory dicts.

        An episode ends at a step flagged in ``dataset['terminals']`` or
        ``dataset['timeouts']``, or once it reaches ``max_episode_len`` steps
        (if given). Each trajectory's ``preference`` is its return vector
        divided by the return's L1 norm, repeated once per step. Steps after
        the last episode boundary are dropped.
        """
        terminal_flags, timeout_flags = dataset['terminals'], dataset['timeouts']
        trajectories = []
        start = 0
        for i in range(len(reward)):
            end = i + 1
            if not (terminal_flags[i] or timeout_flags[i] or end - start == max_episode_len):
                continue
            episode_reward = np.array(reward[start:end])
            ret = episode_reward.sum(axis=0, dtype=np.float64)
            terminals = np.zeros(end - start, dtype=bool)
            terminals[-1] = True
            trajectories.append({
                'observations': np.array(dataset['observations'][start:end]),
                'actions': np.array(dataset['actions'][start:end]),
                'next_observations': np.array(dataset['next_observations'][start:end]),
                'raw_rewards': episode_reward,
                'terminals': terminals,
                'preference': (ret / np.linalg.norm(ret, ord=1)).reshape(1, -1).repeat(end - start, 0),
            })
            start = end
        return trajectories

    def get_normalized_score(self, tot_rewards):
        """Normalize a vector of episode returns according to the dataset class.

        * Single objective: the env's scalar normalization of ``tot_rewards[0]``.
        * ``'d4rl'`` / ``'d4morl'``: returned unchanged.
        * ``'safe'``: the env's ``(reward, cost)`` normalization.
        * ``'cmo'``: the env's vector normalization.
        """
        if self.num_objective == 1:
            return np.array([self.env.get_normalized_score(tot_rewards[0])])
        elif self.dataset_class == 'd4rl' or self.dataset_class == 'd4morl':
            return tot_rewards
        elif self.dataset_class == 'safe':
            norm_r, norm_c = self.env.get_normalized_score(tot_rewards[0], tot_rewards[1])
            return np.array([norm_r, norm_c])
        elif self.dataset_class == 'cmo':
            return self.env.get_normalized_score(tot_rewards)

    def reset(self, **kwargs):
        """Reset the wrapped env; a ``seed`` kwarg also seeds NumPy's global RNG."""
        if 'seed' in kwargs:
            np.random.seed(kwargs['seed'])
        return self.env.reset(**kwargs)
        

def check_dominated(obj_batch, obj, tolerance=0):
    """Return whether some row of ``obj_batch`` dominates ``obj`` (maximization).

    Each row is first scaled by ``1 - tolerance``. A scaled row dominates
    ``obj`` when it is >= ``obj`` in every objective and > ``obj`` in at least
    one objective.
    """
    scaled = obj_batch * (1 - tolerance)
    weakly_better = (scaled >= obj).all(axis=1)
    strictly_better_somewhere = (scaled > obj).any(axis=1)
    return np.logical_and(weakly_better, strictly_better_somewhere).any()

# return sorted indices of nondominated objs
def undominated_indices(obj_batch_input, tolerance=0):
    """Return indices of non-negative, non-dominated rows ordered by the first objective.

    A row is skipped if any of its entries is negative, or if ``check_dominated``
    reports that another row of the batch dominates it (with ``tolerance``).
    Indices are returned in ascending order of objective 0.
    """
    objs = np.array(obj_batch_input)
    order_by_first = np.argsort(objs.T[0])
    return [
        row for row in order_by_first
        if (objs[row] >= 0).all() and not check_dominated(objs, objs[row], tolerance)
    ]



def eval_one_episode(test_vec_env, agent, evalPreference, args, seed=0):
    """Roll out one episode per sub-environment and return the accumulated rewards.

    Args:
        test_vec_env: vectorized environment (gym or gymnasium).
        agent: callable ``agent(state, prefs)``. If it has a ``deterministic``
            attribute it is called with ``deterministic=True``.
        evalPreference: preference rows, one per sub-environment.
        args: namespace with ``device``, ``num_eval_env`` and ``max_episode_len``.
        seed: gymnasium envs are reset with seeds
            ``seed * args.num_eval_env ... (seed + 1) * args.num_eval_env - 1``.

    Returns:
        Per-env sums of reward vectors, or ``0`` if no step was taken. Rewards
        collected after a sub-env's episode ended are masked out. The rollout
        stops when every sub-env is done or after ``args.max_episode_len`` steps.
    """
    num_eval_env = len(evalPreference)
    evalPreference_input = torch.FloatTensor(evalPreference).to(args.device)
    # reset the environment
    if isinstance(test_vec_env, gymnasium.Env):
        env_seeds = list(range(seed * args.num_eval_env, (seed + 1) * args.num_eval_env))
        state, _ = test_vec_env.reset(seed=env_seeds)
    else:
        state = test_vec_env.reset()

    is_terminated = np.zeros((num_eval_env,))
    tot_rewards, cnt = 0, 0
    # interact with the environment
    while not np.all(is_terminated) and args.max_episode_len > cnt:
        if hasattr(agent, 'deterministic'):
            action = agent(state, evalPreference_input, deterministic=True)
        else:
            action = agent(state, evalPreference_input)
        env_out = test_vec_env.step(action)
        next_state, reward, terminal = env_out[0], env_out[1], env_out[2]
        if len(env_out) == 5:
            # gymnasium-style step: (obs, reward, terminated, truncated, info). Gymnasium
            # vector envs auto-reset a sub-env that truncates as well, so a truncated
            # episode must stop accumulating reward exactly like a terminated one.
            terminal = np.logical_or(terminal, env_out[3])
        # Only count rewards of sub-envs whose episode had not ended before this step.
        tot_rewards += reward * (1 - is_terminated.reshape(-1, 1))
        is_terminated = np.logical_or(terminal, is_terminated)
        state = next_state
        cnt += 1
    return tot_rewards

def log_morl_results(args, all_objs, ts, prefix=''):
    """Compute hypervolume and sparsity of evaluated returns and optionally log them.

    Objectives are maximized, so they are negated before being passed to pymoo
    (which minimizes). The hypervolume reference point is the origin.

    Args:
        args: namespace with a ``writer`` exposing ``log`` and ``log_table``, or
            None to skip printing and logging.
        all_objs: list of ``(preference, objective_vector)`` pairs.
        ts: training iteration stored with each logged metric.
        prefix: metric-name prefix.

    Returns:
        ``(hv, sparsity)``.
    """
    recovered_objs = np.array([obj for _, obj in all_objs])
    all_prefs = np.array([pref for pref, _ in all_objs])
    # Preference-weighted scalar return of each evaluation.
    all_scalar_rets = np.array([np.dot(pref, obj) for pref, obj in all_objs])
    perf_ind = get_performance_indicator("hv", ref_point=np.zeros((recovered_objs.shape[1])))
    hv = perf_ind.do(-recovered_objs)
    s = compute_sparsity(-recovered_objs)

    if args is not None:
        mean_ret = np.mean(all_scalar_rets)
        print(f"Prefix: {prefix}, EU: {mean_ret}, hv: {hv}, sparsity: {s}")
        args.writer.log({f"{prefix}/ret": mean_ret, "train_iter": ts})
        args.writer.log({f"{prefix}/hv": hv, "train_iter": ts})
        args.writer.log({f"{prefix}/sparsity": s, "train_iter": ts})

        table_kwargs = {f"pref{i+1}": all_prefs[:, i] for i in range(all_prefs.shape[-1])}
        # Objective columns follow the objective dimension, which need not equal
        # the preference dimension.
        table_kwargs.update({f"obj{i+1}": recovered_objs[:, i] for i in range(recovered_objs.shape[-1])})
        args.writer.log_table(f"{prefix}/Pareto front", **table_kwargs)

    return hv, s



# Evaluate agent druing training 
def eval_agent(test_vec_env, test_env, agent, test_datasets, w_batch_test, args, ts, eval_episodes=1):
    """Evaluate ``agent`` over a set of preferences during training.

    For ``args.dataset`` in ``('safe', 'cmo')`` the preferences are the rows of
    ``w_batch_test``; otherwise they are the ``preference`` of each test
    dataset. Each preference is L1-normalized and evaluated on
    ``args.num_eval_env`` parallel environments. Per-preference results are
    printed, summary metrics go through ``log_morl_results``, a Pareto plot is
    saved for at most 3 objectives, and raw results are saved with
    ``torch.save`` under ``Exps/{name}/{seed}/``.

    Args:
        test_vec_env: environment constructors passed to ``AsyncVectorEnv``.
        test_env: a single ``MOOfflineEnv``-like env used for score normalization.
        eval_episodes: number of sub-env results averaged per preference.

    Returns:
        ``(hv, sparsity, all_objs)`` where ``all_objs`` holds
        ``(preference, mean_objective_vector)`` pairs.
    """
    num_eval_env = args.num_eval_env
    # Pick the vector-env flavour matching the wrapped env's API.
    if isinstance(test_env.env, gym.Env):
        test_vec_env = gym.vector.AsyncVectorEnv(test_vec_env)
    else:
        test_vec_env = gymnasium.vector.AsyncVectorEnv(test_vec_env)
    time_start = time.time()

    try:
        all_eval_data = {'evalPreference': [], 'ret': []}
        all_objs = []

        if args.dataset in ('safe', 'cmo'):
            eval_pref_batch = w_batch_test
        else:
            eval_pref_batch = [np.array(x['preference']) for x in test_datasets]
        for prefer_id, evalPreference in enumerate(eval_pref_batch):
            iter_start = time.time()
            evalPreference = L1_norm(evalPreference)
            evalprefs = np.array([evalPreference] * num_eval_env)
            if hasattr(agent, 'eval_one_episode'):
                unnormalized_tot_rewards = agent.eval_one_episode(prefs=evalprefs)  # PEDA
            else:
                unnormalized_tot_rewards = eval_one_episode(test_vec_env, agent, evalprefs, args, seed=prefer_id)
            if hasattr(test_env, 'set_target_cost'):
                test_env.set_target_cost(test_env.max_episode_cost)
            tot_rewards = np.array([test_env.get_normalized_score(ret) for ret in unnormalized_tot_rewards])
            objs = list(tot_rewards)
            scalar_rets = [np.dot(evalPreference, ret) for ret in tot_rewards]
            for i in range(num_eval_env):
                all_eval_data['evalPreference'].append(evalprefs[i])
                all_eval_data['ret'].append(tot_rewards[i])

            scalar_rets = np.mean(scalar_rets[:eval_episodes])
            objs_mean, objs_std = np.mean(objs[:eval_episodes], axis=0), np.std(objs[:eval_episodes], axis=0)
            all_objs.append((evalPreference, objs_mean))
            print(f'preference: {evalPreference}, ret_vec: {objs_mean} (std: {objs_std}), scalar_rets: {scalar_rets}, time cost: {time.time()-iter_start}', flush=True)

        hv, s = log_morl_results(args, all_objs, ts, prefix='eval')
        # Use the stored preferences rather than the loop variable, which is unset
        # when no preference was evaluated.
        if all_objs and len(all_objs[-1][0]) <= 3:
            plot_pareto_front(args, all_objs, args.name, ext='final_eval')

        save_path = "Exps/{}/{}/".format(args.name, args.seed)
        os.makedirs(save_path, exist_ok=True)
        torch.save(all_eval_data, "{}{}".format(save_path, f'eval_data{args.record_label}.pkl'))
    finally:
        # Always shut down the worker processes, even if evaluation fails.
        test_vec_env.close()
    print("eval time:", time.time()-time_start)

    return hv, s, all_objs


def plot_td_pi_reward(save_path, save_name, w_batch, norm_td_reward, log_pi_reward, ret):
    """Print and plot per-preference adaptation signals.

    First prints every preference together with its normalized TD reward,
    log-probability reward and combined return, sorted by the first preference
    component. Then draws three panels (ret, td_reward, log_pi_reward):

    * 2 objectives: line plot against the first preference component, with a
      vertical line at the batch mean.
    * 3 objectives: scatter over the first two components, coloured by the
      measure.
    * Any other dimension: nothing is plotted.

    The figure is saved to ``os.path.join(save_path, save_name)`` and then closed.

    Args:
        w_batch: ``(n, n_obj)`` tensor of preferences.
        norm_td_reward, log_pi_reward, ret: length-``n`` tensors aligned with ``w_batch``.
    """
    x_dim = 0
    order = torch.argsort(w_batch[:, x_dim])
    for j in order:
        print(f"pref: {w_batch[j].cpu().detach().numpy()}, td: {norm_td_reward[j]}, log_a: {log_pi_reward[j]}, ret: {ret[j]}")

    os.makedirs(save_path, exist_ok=True)
    plt.clf()
    n_obj = w_batch.shape[-1]
    if n_obj not in (2, 3):
        return

    mean_pref = w_batch.mean(dim=0)
    fig = plt.figure(figsize=(20, 5))
    labels = ['ret', 'td_reward', 'log_pi_reward']
    for panel, (label, measure) in enumerate(zip(labels, [ret, norm_td_reward, log_pi_reward])):
        plt.subplot(131 + panel)
        if n_obj == 2:
            xs = [w_batch[j][x_dim].item() for j in order]
            ys = [measure[j].item() for j in order]
            plt.plot(xs, ys, label=label)
            plt.axvline(mean_pref[x_dim].item())
        else:
            sc = plt.scatter(w_batch[:, 0].cpu().detach().numpy(), w_batch[:, 1].cpu().detach().numpy(),
                             c=measure.cpu().detach().numpy(), label=label)
            plt.axvline(mean_pref[0].item())
            plt.axhline(mean_pref[1].item())
            plt.title(f"mean: {mean_pref.cpu().detach().numpy()}")
            plt.colorbar(sc)
        plt.legend()
    plt.savefig(os.path.join(save_path, save_name))
    # This is called repeatedly during adaptation; close the figure to avoid leaking memory.
    plt.close(fig)


class PrefDistAdaptation():
    """Infer a preference distribution from demonstrations by policy-gradient search.

    A diagonal Normal over preference vectors is initialized from
    training-set preference statistics. ``adapt`` then optimizes it with a
    REINFORCE-style objective. That objective rewards preferences under which
    the demonstrations have a small (normalized) TD error and a high action
    log-probability, plus a prior term.
    """

    def __init__(self, buffer, train_dataset, agent, args):
        """Record the TD-error scale of the training data and the preference prior.

        Args:
            buffer: training replay buffer; ``sample(n)`` returns
                ``(state, action, next_state, reward, not_done, prefs)``.
            train_dataset: list of trajectory dicts with a ``preference`` array
                (unused when ``args.dataset == 'cmo'``).
            agent: provides ``get_critic_tderror`` / ``get_vf_tderror`` and
                ``get_log_action_prob``.
            args: needs ``dataset``, ``adpt_td_type`` (``'qf'`` or ``'vf'``),
                ``adpt_fix_dim1`` and ``device``.

        Raises:
            ValueError: if ``args.adpt_td_type`` is neither ``'qf'`` nor ``'vf'``.
        """
        if args.adpt_td_type not in ('qf', 'vf'):
            raise ValueError(f"No such type of td error: {args.adpt_td_type}")
        # calculate the mean of td error in the training dataset
        state, action, next_state, reward, not_done, w_obj_batch = buffer.sample(1024)
        if args.dataset != 'cmo':
            prefs = np.array([traj['preference'][0] for traj in train_dataset])
            prior_pref = [np.mean(prefs, axis=0), np.std(prefs, axis=0)]
        else:
            # Fixed [mean, std] prior for the 'cmo' datasets.
            prior_pref = [np.array([0.5, 0.167, 0.333]), np.array([0.25, 0.25, 0.25])]
        print(f"prior_pref: {prior_pref}")
        with torch.no_grad():
            if args.adpt_td_type == 'qf':
                td_errors = agent.get_critic_tderror(state, action, reward, next_state, w_obj_batch, not_done)
            else:
                td_errors = agent.get_vf_tderror(state, action, reward, next_state, w_obj_batch, not_done)
            td_mean = td_errors.abs().mean(0, keepdim=True)

        # Diagnostic only: mean of get_log_action_prob on a fresh training batch.
        state, action, next_state, reward, not_done, w_obj_batch = buffer.sample(1024)
        with torch.no_grad():
            log_pi_reward = agent.get_log_action_prob(state, action, w_obj_batch)
        print(f"log_pi_reward mean: {log_pi_reward.mean()}")
        self.td_mean, self.prior_pref = td_mean, prior_pref

        # When adpt_fix_dim1 is set, dimension 0 is held at its prior mean and only
        # the remaining dimensions are modelled.
        self.init_mean = torch.tensor(prior_pref[0][args.adpt_fix_dim1:], requires_grad=True).float().to(args.device)
        self.init_logstd = torch.tensor(np.log(prior_pref[1][args.adpt_fix_dim1:]+1e-8), requires_grad=True).float().to(args.device)
        self.init_dist = Normal(self.init_mean, self.init_logstd.exp())

    def adapt(self, args, test_env, agent, demos, real_preference, real_cost_limit, eval_episodes):
        """Fit the preference distribution to ``demos`` and return evaluation preferences.

        Args:
            test_env: environment wrapper exposing ``safe_obj_list``.
            demos: demonstration data loaded into a fresh ``ptan`` replay buffer.
            real_preference, real_cost_limit: ground truth, only used to name debug plots.
            eval_episodes: number of preference rows to return.

        Returns:
            ``(agent, evalprefs, info)``: the unchanged agent, ``eval_episodes``
            identical rows of the selected preference, and
            ``{'prefs': evalprefs, 'predict_pref': ...}``.
        """
        # np.log(1.0) == 0, so the initial std is currently not widened.
        mean, logstd = self.init_mean.clone().detach(), self.init_logstd.clone().detach() + np.log(1.0) # Promote the exploration
        mean.requires_grad_(True)
        logstd.requires_grad_(True)
        adpt_optim = Adam([mean, logstd], lr=args.adpt_lr, betas=(0.9, 0.999))
        replay_buffer_eval = ptan.experience.ReplayBuffer(args)
        replay_buffer_eval.load_from_dataset(test_env, demos)
        adpt_batch_demo = min(replay_buffer_eval.size, args.adpt_batch_demo)
        # Use only a fixed random subset of adpt_batch_demo demo transitions.
        candidate_ind = np.random.choice(np.arange(0, replay_buffer_eval.size), size=adpt_batch_demo, replace=False)
        replay_buffer_eval.fix_sample_ind(candidate_ind)
        constraint_obj_idx = np.where(test_env.safe_obj_list)[0]
        unconstraint_obj_idx = np.where(1-test_env.safe_obj_list)[0]

        # adaptation phase
        for step in range(args.adpt_steps):
            batch = replay_buffer_eval.sample(adpt_batch_demo)
            state, action, next_state, reward, not_done, _ = batch
            # Each transition is repeated adpt_batch_pref times consecutively, while the
            # preference batch is tiled below, so every (transition, preference) pair appears once.
            state = torch.repeat_interleave(state, repeats=args.adpt_batch_pref, dim=0)
            reward = torch.repeat_interleave(reward, repeats=args.adpt_batch_pref, dim=0)
            next_state = torch.repeat_interleave(next_state, repeats=args.adpt_batch_pref, dim=0)
            action = torch.repeat_interleave(action, repeats=args.adpt_batch_pref, dim=0)
            not_done = torch.repeat_interleave(not_done, repeats=args.adpt_batch_pref, dim=0)

            dist = Normal(mean, logstd.exp())
            z_batch = dist.rsample((args.adpt_batch_pref, ))
            # Score-function (REINFORCE) term: log-prob of the detached samples.
            log_probs = dist.log_prob(z_batch.detach()).sum(dim=1)

            if not args.adpt_fix_dim1:
                w_batch = z_batch
            else:
                pref_dim1 = torch.ones((args.adpt_batch_pref, 1), device=args.device).float()*self.prior_pref[0][0]
                w_batch = torch.cat([pref_dim1, z_batch], dim=1)

            w_batch = L1_norm(torch.clip(w_batch, min=0.0, max=1.0))
            w_batch_repeat = torch.tile(w_batch, (adpt_batch_demo, 1))

            with torch.no_grad():
                if args.adpt_td_type=='qf':
                    td_reward = -agent.get_critic_tderror(state, action, reward, next_state, w_batch_repeat, not_done)
                elif args.adpt_td_type=='vf':
                    td_reward = -agent.get_vf_tderror(state, action, reward, next_state, w_batch_repeat, not_done)
                else:
                    raise ValueError('No such type of td error')

                # Average over demo transitions -> one value per sampled preference.
                td_reward = td_reward.reshape(adpt_batch_demo, args.adpt_batch_pref, -1).mean(dim=0)
                norm_td_reward = (td_reward / self.td_mean).mean(dim=1)  # normalize each dim of td_error

                log_pi_reward = agent.get_log_action_prob(state, action, w_batch_repeat).clip(min=-1000)
                log_pi_reward = log_pi_reward.reshape(adpt_batch_demo, args.adpt_batch_pref).mean(dim=0)
                ret = args.adpt_td_weight * norm_td_reward + log_pi_reward + \
                    args.adpt_prior_weight * self.init_dist.log_prob(z_batch).sum(dim=1).detach()/adpt_batch_demo

            # Baseline-subtracted policy gradient, entropy bonus, and a soft sum-to-one penalty on the mean.
            loss = -torch.mean(log_probs * (ret-ret.mean()).detach()) - args.adpt_entropy_weight * dist.entropy().sum()/adpt_batch_demo + (mean.sum()-1.0)**2

            adpt_optim.zero_grad()
            loss.backward()
            adpt_optim.step()
            mean.data, logstd.data = torch.clip(mean.data, min=0.0, max=1.0), torch.clip(logstd.data, -3, 0)

            if (step)%100==0:
                plot_td_pi_reward(f"test_fig/{args.name}/{args.algo}/pref_{real_preference}_thres_{real_cost_limit}", f"ret_{(step)//50}.jpg",
                                  w_batch, norm_td_reward, log_pi_reward, ret)
                print(f"ret: {ret.mean().item()}\n w: {w_batch.mean(dim=0).cpu().detach().numpy()}")
                print(f"mean: {mean.cpu().detach().numpy()}, std: {logstd.exp().cpu().detach().numpy()}\n")

        # CVaR_alpha multiplier of a standard normal: phi(Phi^-1(alpha)) / alpha. Constraint
        # dims are shifted up by weight * std; the unconstrained dims are scaled down to
        # compensate (clipped at 0).
        # NOTE(review): constraint_obj_idx indexes the full objective vector, but `mean`
        # omits dimension 0 when adpt_fix_dim1 is set -- confirm the intended offset.
        weight = scipy.stats.norm.pdf(scipy.stats.norm.ppf(q=args.cvar_alpha)) / args.cvar_alpha
        pref_sample = mean.clone()
        pref_sample[constraint_obj_idx] += weight*logstd[constraint_obj_idx].exp()
        pref_sample[unconstraint_obj_idx] *= torch.clip(1.0-weight*logstd[constraint_obj_idx].exp().sum()/mean[unconstraint_obj_idx].sum(), min=0.0)
        pref_sample = pref_sample.reshape(1, -1).repeat(eval_episodes, 1)   # another implementation: pref_sample = dist.sample((1000000,))
        if args.adpt_fix_dim1:
            pref_dim1 = torch.ones((len(pref_sample), 1), device=args.device).float()*self.prior_pref[0][0]
            pref_sample = torch.cat([pref_dim1, pref_sample], dim=1)
        pref_sample = L1_norm(pref_sample)
        predict_pref = pref_sample.mean(0).cpu().detach().numpy()
        predict_pref_mean, predict_pref_std = mean.cpu().detach().numpy(), logstd.exp().cpu().detach().numpy()
        evalprefs = pref_sample[:eval_episodes].cpu().detach().numpy()
        print(f'Adaptation preference distribution: mean {predict_pref_mean}, std {predict_pref_std}')
        print(f'Predicted preference: {predict_pref}')
        info = {'prefs':evalprefs, 'predict_pref': predict_pref}
        return agent, evalprefs, info

class FinetuneAdaptation():
    """Adapt by behaviour-cloning a copy of the agent on the demonstrations."""

    def __init__(self, buffer, train_dataset, agent, args):
        # Nothing to precompute; the signature mirrors the other adaptation classes.
        pass

    def adapt(self, args, test_env, agent, demos, real_preference, real_cost_limit, eval_episodes):
        """Fine-tune a deep copy of ``agent`` on ``demos`` with ``train_bc_policy``.

        A fixed random subset of ``min(buffer size, args.adpt_batch_demo)``
        demo transitions is used. The actor gets a fresh Adam optimizer with
        ``args.finetune_lr`` and is trained for ``args.finetune_step`` steps.
        The caller's ``agent`` is left untouched.

        Returns:
            ``(agent_finetune, evalprefs, {})``: the fine-tuned copy, an all-zero
            ``(eval_episodes, args.reward_size)`` preference array, and an
            empty info dict.
        """
        replay_buffer_eval = ptan.experience.ReplayBuffer(args)
        replay_buffer_eval.load_from_dataset(test_env, demos)
        adpt_batch_demo = min(replay_buffer_eval.size, args.adpt_batch_demo)
        # use only demos with number adpt_batch_demo
        candidate_ind = np.random.choice(np.arange(0, replay_buffer_eval.size), size=adpt_batch_demo, replace=False)
        replay_buffer_eval.fix_sample_ind(candidate_ind)
        agent_finetune = copy.deepcopy(agent)
        agent_finetune.actor_optimizer = torch.optim.Adam(agent_finetune.actor.parameters(), lr=args.finetune_lr)
        # finetune phase
        for _ in range(args.finetune_step):
            agent_finetune.train_bc_policy(replay_buffer_eval, args.writer)
        evalprefs = np.zeros((eval_episodes, args.reward_size))
        # Return the fine-tuned copy; returning `agent` would discard all fine-tuning.
        return agent_finetune, evalprefs, {}
    
class PromptMODTAdaptation():
    """Adapt a prompt-conditioned agent by handing it a segment of the first demo."""

    def __init__(self, buffer, train_dataset, agent, args):
        # No state is needed; the signature mirrors the other adaptation classes.
        pass

    def adapt(self, args, test_env, agent, demos, real_preference, real_cost_limit, eval_episodes):
        """Build one prompt from ``demos[0]`` and reuse it for every evaluation episode.

        The prompt is the segment returned by ``get_traj_seg`` with
        ``seg_len=args.adpt_batch_demo``, keeping the first entry of each field.

        Returns:
            ``(agent, prefs, info)``. ``prefs`` repeats the first preference row
            of ``demos[0]`` ``eval_episodes`` times. ``info`` holds ``prefs``
            and the per-episode ``prompts``.
        """
        first_demo = demos[0]
        prompt_buffer = ptan.experience.PromptTrajReplayBuffer(args, None, None)
        segment = prompt_buffer.get_traj_seg(first_demo, seg_len=args.adpt_batch_demo, maximize_seglen=True)
        prompt = tuple(field[0] for field in segment)
        prompts = [prompt] * eval_episodes
        prefs = np.repeat(first_demo['preference'][0].reshape(1, -1), eval_episodes, axis=0)
        return agent, prefs, {'prefs': prefs, 'prompts': prompts}


# Evaluate agent adaptive to expert demo
def eval_agent_adaptation(adaptation_method, test_vec_env, test_env, agent, args, ts, test_datasets, eval_episodes=1):
    """Evaluate how well ``adaptation_method`` adapts ``agent`` to each expert demo.

    For every entry of ``test_datasets``:
    1. The adaptation method produces an adapted agent and preferences.
    2. The adapted agent is rolled out.
    3. The normalized returns are compared with the ground-truth preference and
       cost limit.

    Per-demo metrics are logged through ``args.writer``. Afterwards:
    - Hypervolume and sparsity are computed separately for every distinct cost limit.
    - Raw results are saved to ``Exps/<name>/<seed>/adpt_eval_data<record_label>.pkl``.
    - A summary plot is written to ``Exps/<name>/<seed>/adpt_performance.jpg``.

    Args:
        adaptation_method: object whose ``adapt(args, test_env, agent, demos,
            real_preference, real_cost_limit, eval_episodes)`` returns
            ``(adapted_agent, adapted_prefs, agent_input)``.
        test_vec_env: passed straight to ``AsyncVectorEnv``. Presumably a list of
            env constructor callables — confirm with callers.
        test_env: single env providing ``safe_obj_list``, ``get_normalized_score``
            and (optionally) ``set_target_cost``.
        agent: agent handed to the adaptation method.
        args: run configuration; reads ``reward_size``, ``writer``, ``name``,
            ``seed`` and ``record_label``.
        ts: forwarded to ``log_morl_results``.
        test_datasets: iterable of dicts with keys ``'preference'``,
            ``'cost_limit'`` and ``'demo'``.
        eval_episodes: number of evaluation episodes kept per demo.

    Returns:
        ``(mean_hv, mean_sp, all_objs)``, where ``all_objs`` holds one
        ``(real_preference, real_cost_limit, objs_mean)`` tuple per demo.
    """
    if isinstance(test_env.env, gym.Env):
        test_vec_env = gym.vector.AsyncVectorEnv(test_vec_env)
    else:
        test_vec_env = gymnasium.vector.AsyncVectorEnv(test_vec_env)
    time_start = time.time()
    reward_size = args.reward_size
    # Boolean mask works whether safe_obj_list is a list or an array.
    safe_mask = np.asarray(test_env.safe_obj_list, dtype=bool)
    constraint_obj_idx = np.where(safe_mask)[0]
    unconstraint_obj_idx = np.where(~safe_mask)[0]

    all_objs = []
    all_performance = []

    try:
        for data_id, data in enumerate(test_datasets):
            iter_start = time.time()
            real_preference = np.array(data['preference'])  # ground-truth preference
            real_cost_limit = np.array(data['cost_limit'])  # ground-truth threshold(s)
            demos = data['demo']
            print(f'Real preference: {real_preference}, Real cost limit: {real_cost_limit} (-1 means unlimited)')
            adapted_agent, adapted_prefs, agent_input = adaptation_method.adapt(
                args, test_env, agent, demos, real_preference, real_cost_limit, eval_episodes)

            if hasattr(test_env, 'set_target_cost'):
                test_env.set_target_cost(real_cost_limit[-1])  # set threshold for safe RL
            # Agents that run their own rollout (e.g. prompt-conditioned ones) take agent_input;
            # all others are evaluated on the shared vectorized env.
            if hasattr(adapted_agent, 'eval_one_episode'):
                unnormalized_tot_rewards = adapted_agent.eval_one_episode(**agent_input)
            else:
                unnormalized_tot_rewards = eval_one_episode(test_vec_env, adapted_agent, adapted_prefs, args, seed=data_id)
            objs = np.array([test_env.get_normalized_score(ret) for ret in unnormalized_tot_rewards])
            episode_objs = objs[:eval_episodes]
            objs_mean, objs_std = np.mean(episode_objs, axis=0), np.std(episode_objs, axis=0)
            all_objs.append((real_preference, real_cost_limit, objs_mean))
            print(f'Unnormalized returns: {unnormalized_tot_rewards}')
            print(f'Return: {objs_mean} (std: {objs_std}), time cost: {time.time()-iter_start}\n', flush=True)
            print('--------------------------------------------------------------------------\n')

            mean_adapted_pref = adapted_prefs.mean(0)
            predict_pref = agent_input.get('predict_pref', mean_adapted_pref)
            # Utility and preference error are measured on the unconstrained objectives only.
            utility = np.dot(objs_mean[unconstraint_obj_idx], real_preference[unconstraint_obj_idx])
            predict_unconstraint_pref = L1_norm(mean_adapted_pref[unconstraint_obj_idx])
            real_uncontraint_pref = L1_norm(real_preference[unconstraint_obj_idx])
            pref_L1_loss = np.mean(np.abs(predict_unconstraint_pref - real_uncontraint_pref))
            demo_max_cost = np.max(objs_mean[constraint_obj_idx]) if len(constraint_obj_idx) else 0

            cost_limit_dict = {f"adpt/cost_limit_{i}": v for i, v in enumerate(real_cost_limit)}
            real_pref_dict = {f"adpt/real_pref_{i}": v for i, v in enumerate(real_preference)}
            predict_pref_dict = {f"adpt/predict_pref_{i}": v for i, v in enumerate(predict_pref)}
            predict_unconstraint_pref_dict = {f"adpt/predict_unconstraint_pref_{i}": v for i, v in enumerate(predict_unconstraint_pref)}
            args.writer.log({"adpt/demo_id": data_id, "adpt/utility": utility, "adpt/pref_L1_loss": pref_L1_loss,
                             **cost_limit_dict, **real_pref_dict, **predict_pref_dict, **predict_unconstraint_pref_dict,
                             "adpt/max_cost": demo_max_cost})
            all_performance.append({'normalized_cost': objs_mean[constraint_obj_idx], 'utility': utility,
                                    'pref_L1_loss': pref_L1_loss, 'predict_unconstraint_pref': predict_unconstraint_pref,
                                    'cost_limit': real_cost_limit, 'real_preference': real_preference,
                                    'evalPreference': adapted_prefs[:eval_episodes], 'objs': episode_objs})
            for idx in constraint_obj_idx:
                print(f'Objective {idx}, Real constraint threshold: {real_cost_limit[idx]}, Normalized_cost: {objs_mean[idx]}')
            print(f'Max cost: {demo_max_cost}, utility: {utility}')
            print(f'Real preference: {real_uncontraint_pref}, predict unconstrained preference: {predict_unconstraint_pref}, pref_L1_loss: {pref_L1_loss}')
            print('--------------------------------------------------------------------------\n')
    finally:
        # Always shut down the vectorized env's worker processes, even if a demo fails.
        test_vec_env.close()

    # Hypervolume / sparsity are computed per distinct cost limit, over all episodes
    # of all demos that share that limit.
    all_hv, all_sp = [], []
    for limit in {tuple(x['cost_limit']) for x in all_performance}:
        pref_list, obj_list = [], []
        for x in all_performance:
            if tuple(x['cost_limit']) == limit:
                pref_list.extend(L1_norm(x['evalPreference'][:, unconstraint_obj_idx]))
                obj_list.extend(x['objs'][:, unconstraint_obj_idx])
        hv, sp = log_morl_results(args, list(zip(pref_list, obj_list)), ts, prefix=f'adpt/cost_limit_{limit}')
        all_hv.append(hv)
        all_sp.append(sp)
        # Only plot low-dimensional fronts; guard against a limit with no recorded episodes.
        if pref_list and len(pref_list[0]) <= 3:
            plot_pareto_front(args, list(zip(pref_list, obj_list)), args.name, ext=f'final_adpt_eval_thres_{limit}')

    save_path = "Exps/{}/{}/".format(args.name, args.seed)
    os.makedirs(save_path, exist_ok=True)
    torch.save(all_performance, "{}{}".format(save_path, f'adpt_eval_data{args.record_label}.pkl'))

    mean_utility = np.mean([x['utility'] for x in all_performance])
    mean_pref_L1_loss = np.mean([x['pref_L1_loss'] for x in all_performance])
    max_cost = np.max([np.max(x['normalized_cost']) for x in all_performance]) if len(constraint_obj_idx) else 0
    print(f"mean_utility: {mean_utility}, mean_pref_L1_loss: {mean_pref_L1_loss}, adpt/max_cost: {max_cost}")
    print("adaptation time:", time.time()-time_start)
    args.writer.log({"adpt/mean_utility": mean_utility, "adpt/mean_pref_L1_loss": mean_pref_L1_loss,
                     "adpt/max_all_cost": max_cost, "adpt/mean_hv": np.mean(all_hv), "adpt/mean_sp": np.mean(all_sp)})

    # NOTE(review): the x-axis is cost_limit[1], which assumes every cost_limit has
    # at least two entries — confirm this holds for all environments.
    cost_axis = [x['cost_limit'][1] for x in all_performance]
    fig = plt.figure(figsize=(12, 5))
    plt.subplot(121)
    plt.scatter(cost_axis, [np.max(x['normalized_cost']) if len(constraint_obj_idx) else 0 for x in all_performance], label=f'cost, max:{max_cost}')
    plt.scatter(cost_axis, [x['utility'] for x in all_performance], label=f'utility, mean: {mean_utility}')
    plt.scatter(cost_axis, [x['evalPreference'].mean(0)[0] for x in all_performance], label='predicted_pref_dim1')
    plt.plot([0, 1.0], [1.0, 1.0])
    plt.legend()
    plt.subplot(122)
    plt.scatter([x['real_preference'][0] for x in all_performance], [x['pref_L1_loss'] for x in all_performance], label=f'pref_L1_loss, mean: {mean_pref_L1_loss}')
    plt.legend()
    fig.savefig(save_path + "adpt_performance.jpg")
    # Close the figure so repeated evaluations do not accumulate open figures.
    plt.close(fig)
    return np.mean(all_hv), np.mean(all_sp), all_objs



